Skip to content

[NNX] NNX migration (11/N): set pure_nnx / enable_nnx / pure_nnx_decoder defaults to True#3526

Open
ecnal-cienet wants to merge 1 commit into
mainfrom
feat/nnx-set-defaults-true
Open

[NNX] NNX migration (11/N): set pure_nnx / enable_nnx / pure_nnx_decoder defaults to True#3526
ecnal-cienet wants to merge 1 commit into
mainfrom
feat/nnx-set-defaults-true

Conversation

@ecnal-cienet

@ecnal-cienet ecnal-cienet commented Mar 31, 2026

Copy link
Copy Markdown
Collaborator

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)
  4. ✅ Sharding diagnostics on NNX + post-training bugfixes that surfaced once the NNX path was exercised end-to-end. (PR [NNX] NNX migration prep (4/N): sharding tools and post-training fixes #3652)
  5. ✅ NNX correctness fixes, feature enablements, vocab tiling on NNX.
  6. ✅ NNX-native DPO.
  7. ✅ NNX-native MaxEngine inference. (PR [NNX] NNX migration prep (7/N): NNX-native MaxEngine inference #3821)
  8. ✅ NNX-native LoRA + GRPO. (PR [NNX] NNX migration prep (8/N): NNX native lora grpo #3824)
  9. ✅ NNX-aware QK-Clip + remaining checkpoint utilities. (PR [NNX] NNX migration prep (9/N): NNX-aware QK-Clip + checkpoint utilities #3836)
    9.5. ✅ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix. (PR [NNX] NNX migration prep (9.5/N): NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix #3844)
  10. ✅ Vocab tiling custom_vjp for NNX. (PR [NNX] NNX migration prep (10/N): vocab tiling custom_vjp with output-head carve-out #3849)
  11. 🔄 [This PR] Flip enable_nnx, pure_nnx, pure_nnx_decoder from FalseTrue in base.yml. Bundle the NNX-only fixes that surface once pure_nnx=True.
  12. ❌ Delete Linen-specific code paths and NNX compatibility flags.

Description

PR6–PR10 promoted every routed-to-Linen feature to NNX-native; #2885 added NNX-native pipeline parallelism; #4040 added Qwix on NNX. This PR flips the three defaults in base.yml and bundles the NNX-only fixes that surface once pure_nnx=True.

Changes

src/maxtext/configs/base.yml — flip defaults

  • enable_nnx: False → True, pure_nnx: False → True, pure_nnx_decoder: False → True.

src/maxtext/utils/sharding.py — Zero-1 on flat nnx.State

  • New build_zero1_input_state_mesh_shardings overlays Param-leaf shardings on the flat nnx.State. The Linen call state_mesh_shardings.replace(params=...) only works on TrainState.

src/maxtext/trainers/pre_train/train.py, train_compile.py — NNX dispatch

  • AOT compile dispatches to the Zero-1 NNX builder when pure_nnx=True.
  • Pops Intermediate sown variables before grad so MTP auxiliary losses aren't differentiated as part of the main loss.

src/maxtext/trainers/diloco/diloco.py, src/maxtext/common/checkpointing.py — DiLoCo under NNX

  • DiLoCoTrainState.merge / .split use nnx.split, guarded against double-merging.
  • maybe_save_checkpoint reads state.step under enable_diloco, otherwise state.optimizer.step.
  • replace_nnx_model_params identifies "model" leaves by path via tree_flatten_with_path. Preserves the original treedef so lax.cond branches still match, and is robust to future key additions to inner_state (addresses #3526 review comment).

src/maxtext/utils/generate_param_only_checkpoint.py — NNX param-only restore

  • Pure-dict restore ({"value": ...} wrapping), opt_state path skipping, bf16 cast skipping rng leaves.

src/maxtext/inference/maxengine/maxengine.py — drop Linen-vs-NNX parity asserts

  • NNX-only prefill/decode/cache assertions stay.

src/maxtext/layers/nnx_wrappers.py — modernize ToLinen.__call__

  • Drop the _refresh_variable_trace_state private-state workaround; use idiomatic nnx.split / nnx.update / nnx.merge and filter unknown paths before assignment (addresses #3526 review comment).

src/maxtext/layers/quantizations.py — OSS qwix import fix from #4040

  • from qwix._src.utils import flax_utilfrom qwix._src import flax_util. PR#4040 (Copybara) referenced the Google-internal qwix._src.utils path; OSS qwix has flax_util directly under _src/. Unblocks OSS CI.

src/maxtext/utils/{muon_utils,qk_clip_utils,train_utils}.py — NNX-shape adjustments

  • muon_utils.get_muon_weight_dimension_numbers dispatches by NNX-vs-Linen state shape.
  • qk_clip_utils broadcasts over the correct axis under NNX.
  • train_utils.jit_train_step threads dropout_rng=None on the NNX path.

src/maxtext/trainers/post_train/sft/train_sft_native.py — SFT NNX path

  • nnx.split(state) before jit, dropout-rng threaded conditionally (mirrors the pre-train path).

src/maxtext/checkpoint_conversion/{to_maxtext.py, utils/utils.py} — shared helper

  • Route path-key parsing through the shared param_key_parts_from_path helper; to_maxtext.py also handles nn.LogicallyPartitioned shape access correctly.

Tests

  • tests/unit/tiling_test.py::LossAndGradientCorrectnessTest — pin to Linen in setUp (builds via transformer_as_linen); drop 6 stale pytest.skip("vocab tiling on NNX") guards (now NNX-native via PR10).
  • tests/integration/maxengine_test.py — drop Linen-vs-NNX prefill/decode parity tests; NNX-only assertions kept.
  • tests/unit/max_utils_test.py — pin UnscanTest to Linen via init_pyconfig; drop the three hasattr(state, "model") branches (addresses #3526 review comment).
  • tests/integration/diloco_test.py — NNX training-loop simulation + checkpoint coverage.
  • tests/integration/generate_param_only_checkpoint_test.py — NNX param-only restore coverage.
  • tests/unit/{muon_utils,maxtext_utils,optimizers,state_dtypes,train_state_nnx_checkpoint}_test.py — adjusted for TrainStateNNX / flat nnx.State shapes.
  • tests/unit/qwen3_next_vs_reference_test.pyepsepsilon rename in Qwen3NextRMSNorm_PT + related cleanup.
  • tests/integration/tokamax_test.py — split parameterized test into gmm_bf16 and gmm_fp8 cases.

Stats

  • Diff: +591 / −445 across 27 files.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch 5 times, most recently from bac289f to db75887 Compare April 6, 2026 21:09
@ecnal-cienet ecnal-cienet changed the title Feat/nnx set defaults true NNX migration prep (5/N): enable NNX by default Apr 6, 2026
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch 17 times, most recently from 5a7f63b to 73213e0 Compare April 9, 2026 23:47
@ecnal-cienet ecnal-cienet changed the title NNX migration prep (5/N): enable NNX by default NNX migration prep (6/N): enable NNX by default Apr 16, 2026
@ecnal-cienet ecnal-cienet changed the title NNX migration prep (6/N): enable NNX by default NNX migration prep (5/N): enable NNX by default Apr 20, 2026
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch 4 times, most recently from 7e33a09 to 2f34cfb Compare April 28, 2026 14:16
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the feat/nnx-set-defaults-true branch from f4674bb to b7d1f6d Compare May 22, 2026 10:29
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch 2 times, most recently from 9d05b96 to 450ef8d Compare May 25, 2026 15:26
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the feat/nnx-set-defaults-true branch 5 times, most recently from e420909 to 8a27207 Compare May 26, 2026 09:27
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch 10 times, most recently from 2d3b8a6 to 52219b5 Compare May 28, 2026 20:08
@github-actions

Copy link
Copy Markdown

🤖 Hi @ecnal-cienet, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details.

1 similar comment
@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details.

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📋 Review Summary

This Pull Request successfully transitions MaxText to use JAX's NNX API by default by flipping enable_nnx, pure_nnx, and pure_nnx_decoder to True. The changes are comprehensive and robust, addressing flat NNX state checkpointing, DiLoCo training, parameter-only restoration, and pinning Linen-coupled integration/unit tests.

🔍 General Feedback

  • High-Quality Code Migration: The migration of core features (including DiLoCo, param-only generation, and maxengine) to support NNX-native behavior is very well-structured and thoroughly covered by updated tests.
  • Robust Sharding Alignments: Adopting build_zero1_input_state_mesh_shardings to overlay Param-leaf shardings on the flat nnx.State ensures seamless compatibility with ZeRO-1 optimizers under NNX.
  • Preserved Parity: Pinning complex pipeline parallelism and fp8/sparsity tests to the Linen path is a pragmatic decision that avoids regressions while these features are migrated in subsequent iterations.

Comment thread src/maxtext/trainers/diloco/diloco.py
@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details.

1 similar comment
@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details.

Comment thread src/maxtext/configs/types.py Outdated
PR6-PR10 promoted every routed-to-Linen feature to NNX-native; PR#2885 lands NNX-native pipeline parallelism. This PR flips the three defaults in base.yml so NNX is the production path, and bundles the NNX-only fixes that surface once pure_nnx=True (DiLoCo merge/checkpoint, Zero-1 input shardings on flat nnx.State, MTP sown-Variable handling, generate_param_only_checkpoint NNX flow, maxengine Linen-parity removal).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants